{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# L4: Multimodal Retrieval Augmented Generation (MM-RAG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6e4; padding:15px; border-width:3px; border-color:#f5ecda; border-style:solid; border-radius:6px\"> ⏳ <b>Note <code>(Kernel Starting)</code>:</b> This notebook takes about 30 seconds to be ready to use. You may start and watch the video while you wait.</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">In this lesson you'll learn how to leverage Weaviate and Google Gemini Pro Vision to carry out a simple multimodal RAG workflow."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* In this classroom, the libraries have been already installed for you.\n",
    "* If you would like to run this code on your own machine, you need to install the following:\n",
    "```\n",
    "    !pip install -U weaviate-client\n",
    "    !pip install google-generativeai\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "### Load environment variables and API keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv()) # read local .env file\n",
    "\n",
    "EMBEDDING_API_KEY = os.getenv(\"EMBEDDING_API_KEY\")\n",
    "GOOGLE_API_KEY=os.getenv(\"GOOGLE_API_KEY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Note: learn more about [GOOGLE_API_KEY](https://ai.google.dev/) to run it locally."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Connect to Weaviate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 251
   },
   "outputs": [],
   "source": [
    "import weaviate\n",
    "\n",
    "client = weaviate.connect_to_embedded(\n",
    "    version=\"1.24.4\",\n",
    "    environment_variables={\n",
    "        \"ENABLE_MODULES\": \"backup-filesystem,multi2vec-palm\",\n",
    "        \"BACKUP_FILESYSTEM_PATH\": \"/home/jovyan/work/backups\",\n",
    "    },\n",
    "    headers={\n",
    "        \"X-PALM-Api-Key\": EMBEDDING_API_KEY,\n",
    "    }\n",
    ")\n",
    "\n",
    "client.is_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Restore 13k+ prevectorized resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 183
   },
   "outputs": [],
   "source": [
    "client.backup.restore(\n",
    "    backup_id=\"resources-img-and-vid\",\n",
    "    include_collections=\"Resources\",\n",
    "    backend=\"filesystem\"\n",
    ")\n",
    "\n",
    "# It can take a few seconds for the \"Resources\" collection to be ready.\n",
    "# We add 5 seconds of sleep to make sure it is ready for the next cells to use.\n",
    "import time\n",
    "time.sleep(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preview data count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 200
   },
   "outputs": [],
   "source": [
    "from weaviate.classes.aggregate import GroupByAggregate\n",
    "\n",
    "resources = client.collections.get(\"Resources\")\n",
    "\n",
    "response = resources.aggregate.over_all(\n",
    "    group_by=GroupByAggregate(prop=\"mediaType\")\n",
    ")\n",
    "\n",
    "# print rounds names and the count for each\n",
    "for group in response.groups:\n",
    "    print(f\"{group.grouped_by.value} count: {group.total_count}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multimodal RAG"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1 – Retrieve content from the database with a query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 268
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "from weaviate.classes.query import Filter\n",
    "\n",
    "def retrieve_image(query):\n",
    "    resources = client.collections.get(\"Resources\")\n",
    "# ============\n",
    "    response = resources.query.near_text(\n",
    "        query=query,\n",
    "        filters=Filter.by_property(\"mediaType\").equal(\"image\"), # only return image objects\n",
    "        return_properties=[\"path\"],\n",
    "        limit = 1,\n",
    "    )\n",
    "# ============\n",
    "    result = response.objects[0].properties\n",
    "    return result[\"path\"] # Get the image path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run image retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "# Try with different queries to retreive an image\n",
    "img_path = retrieve_image(\"fishing with my buddies\")\n",
    "display(Image(img_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6ff; padding:15px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px\"> 💻 &nbsp; <b>Access Files and Helper Functions:</b> To access the files for this notebook, 1) click on the <em>\"File\"</em> option on the top menu of the notebook and then 2) click on <em>\"Open\"</em>. For more help, please see the <em>\"Appendix - Tips and Help\"</em> Lesson.</p>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2 - Generate a description of the image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 212
   },
   "outputs": [],
   "source": [
    "import google.generativeai as genai\n",
    "from google.api_core.client_options import ClientOptions\n",
    "\n",
    "# Set the Vision model key\n",
    "genai.configure(\n",
    "        api_key=GOOGLE_API_KEY,\n",
    "        transport=\"rest\",\n",
    "        client_options=ClientOptions(\n",
    "            api_endpoint=os.getenv(\"GOOGLE_API_BASE\"),\n",
    "        ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 302
   },
   "outputs": [],
   "source": [
    "# Helper function\n",
    "import textwrap\n",
    "import PIL.Image\n",
    "from IPython.display import Markdown, Image\n",
    "\n",
    "def to_markdown(text):\n",
    "    text = text.replace(\"•\", \"  *\")\n",
    "    return Markdown(textwrap.indent(text, \"> \", predicate=lambda _: True))\n",
    "\n",
    "def call_LMM(image_path: str, prompt: str) -> str:\n",
    "    img = PIL.Image.open(image_path)\n",
    "\n",
    "    model = genai.GenerativeModel(\"gemini-pro-vision\")\n",
    "    response = model.generate_content([prompt, img], stream=True)\n",
    "    response.resolve()\n",
    "\n",
    "    return to_markdown(response.text)    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run vision request"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "call_LMM(img_path, \"Please describe this image in detail.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Note: Please be aware that the output from the previous cell may differ from what is shown in the video. This variation is normal and should not cause concern."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## All together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 166
   },
   "outputs": [],
   "source": [
    "def mm_rag(query):\n",
    "    # Step 1 - retrieve an image – Weaviate\n",
    "    SOURCE_IMAGE = retrieve_image(query)\n",
    "    display(Image(SOURCE_IMAGE))\n",
    "#===========\n",
    "\n",
    "    # Step 2 - generate a description - GPT4\n",
    "    description = call_LMM(SOURCE_IMAGE, \"Please describe this image in detail.\")\n",
    "    return description"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "# Call mm_rag function\n",
    "mm_rag(\"paragliding through the mountains\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "# Remember to close the weaviate instance\n",
    "client.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Try it yourself! \n",
    "\n",
    "Run the cells above selecting another image from the database and generate a description for it!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
